#!/usr/bin/env python

# (c) 2016 Matt Clay <matt@mystile.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function

import os
import subprocess
import sys

from os import path
from argparse import ArgumentParser

import ansible.constants as C

from ansible.playbook import Playbook
from ansible.vars import VariableManager
from ansible.parsing.dataloader import DataLoader


def main():
    """Parse the command line and generate the integration test script.

    Job parameters come from the command line, falling back to the IMAGE,
    PRIVILEGED, PYTHON3, PLATFORM and VERSION environment variables. Either
    an image, or both a platform and a version, must be provided.

    Raises:
        Exception: If neither an image nor a platform/version pair is given.
    """

    C.DEPRECATION_WARNINGS = False

    arg_parser = ArgumentParser(description='Generate an integration test script for changed modules.')
    arg_parser.add_argument('module_group', choices=['core', 'extras'], help='module group to test')
    arg_parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='write verbose output to stderr')
    arg_parser.add_argument('--changes', dest='changes', default=None,
                            help='file listing changed paths (default: query git)')
    arg_parser.add_argument('--image', dest='image', default=os.environ.get('IMAGE'),
                            help='image to run tests with')
    arg_parser.add_argument('--privileged', dest='privileged', action='store_true',
                            default=os.environ.get('PRIVILEGED') == 'true',
                            help='run container in privileged mode')
    arg_parser.add_argument('--python3', dest='python3', action='store_true',
                            default=os.environ.get('PYTHON3', '') != '',
                            help='run tests using python3')
    arg_parser.add_argument('--platform', dest='platform', default=os.environ.get('PLATFORM'),
                            help='platform to run tests on')
    arg_parser.add_argument('--version', dest='version', default=os.environ.get('VERSION'),
                            help='version of platform to run tests on')
    arg_parser.add_argument('--output', dest='output', required=True,
                            help='path to write output script to')

    opts = arg_parser.parse_args()

    # Test targets (playbook names) used for each kind of host.
    posix_targets = ['non_destructive', 'destructive']
    windows_targets = ['test_win_group1', 'test_win_group2', 'test_win_group3']

    use_image = opts.image is not None
    use_remote = opts.platform is not None and opts.version is not None

    if not use_image and not use_remote:
        raise Exception('job parameters not specified')

    if use_image:
        # An image takes precedence over platform/version when both are given.
        runner = 'integration'
        selected_targets = posix_targets
        extras = []
        if opts.privileged:
            extras.append(' PRIVILEGED=true')
        if opts.python3:
            extras.append(' PYTHON3=1')
        job = 'IMAGE=%s%s' % (opts.image, ''.join(extras))
    else:
        runner = 'remote'
        job = 'PLATFORM=%s VERSION=%s' % (opts.platform, opts.version)
        selected_targets = windows_targets if opts.platform == 'windows' else posix_targets

    generate_test_commands(opts.module_group, selected_targets, runner, opts.output,
                           jobs=[job], verbose=opts.verbose, changes=opts.changes)


def generate_test_commands(module_group, targets, script, output, jobs=None, verbose=False, changes=None):
    """Generate test commands for the given module group and test targets.

    One command line is written to the output script per job. If no changed
    paths, no changed modules, or no matching test tags are found, a message
    is written to stderr and the process exits (status 0) without writing
    the output script.

    Args:
        module_group: The module group (core, extras) to examine.
        targets: The test targets to examine.
        script: The script used to execute the test targets.
        output: The path to write the output script to.
        jobs: Iterable of job parameter strings, each prefixed to one command.
            Auto-detection is not implemented here, so despite the None
            default a value must be supplied.
        verbose: True to write detailed output to stderr.
        changes: Path to file containing list of changed files, or None to query git.
    """

    base_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', '..'))

    # The job config path is only reported in verbose output; it is not read.
    job_config_path = path.join(base_dir, 'shippable.yml')
    module_dir = os.path.join(base_dir, 'lib', 'ansible', 'modules', module_group)

    if verbose:
        print_stderr(' config: %s' % job_config_path)
        print_stderr('modules: %s' % module_dir)
        print_stderr('targets: %s' % ' '.join(targets))
        print_stderr()

    if changes is None:
        paths_changed = get_changed_paths(module_dir)
    else:
        with open(changes, 'r') as f:
            paths_changed = f.read().splitlines()

    # Drop blank entries so that empty input yields an empty list rather
    # than [''], which would defeat the "no changes" check below.
    paths_changed = [p.strip() for p in paths_changed if p.strip()]

    if not paths_changed:
        print_stderr('No changes to files detected.')
        sys.exit()

    if verbose:
        dump_stderr('paths_changed', paths_changed)

    modules_changed = get_modules(paths_changed)

    if not modules_changed:
        print_stderr('No changes to modules detected.')
        sys.exit()

    if verbose:
        dump_stderr('modules_changed', modules_changed)

    module_tags = get_module_test_tags(modules_changed)

    if verbose:
        dump_stderr('module_tags', module_tags)

    available_tags = get_target_tags(base_dir, targets)

    if verbose:
        dump_stderr('available_tags', available_tags)

    # Only run tags which both correspond to a changed module and exist
    # in the playbooks for the requested targets.
    use_tags = module_tags & available_tags

    if not use_tags:
        print_stderr('No tagged test roles found for changed modules.')
        sys.exit()

    if verbose:
        dump_stderr('use_tags', use_tags)

    target = ' '.join(targets)
    # Sort the tags so the generated script is reproducible between runs.
    tags = ','.join(sorted(use_tags))
    script_path = 'test/utils/shippable/%s.sh' % script

    commands = ['TARGET="%s" TEST_FLAGS="-t %s" %s %s' % (target, tags, j, script_path) for j in jobs]

    with open(output, 'w') as f:
        # writelines() does not add separators, so join explicitly to keep
        # each job's command on its own line.
        f.write('\n'.join(commands))
        f.write('\n')


def print_stderr(*args, **kwargs):
    """Print the given arguments to stderr and flush it immediately.

    Accepts the same positional and keyword arguments as print(), except
    for 'file', which is always stderr.
    """

    stream = sys.stderr
    print(*args, file=stream, **kwargs)
    stream.flush()


def dump_stderr(label, l):
    """Write a labelled collection of strings to stderr.

    The output is a '[label:count]' header, followed by one item per line
    and a trailing blank line.

    Args:
        label: The label to print for this collection.
        l: The sized collection of strings to dump to stderr.
    """

    header = '[%s:%s]' % (label, len(l))
    body = '\n'.join(l)

    print_stderr(header + '\n' + body + '\n')


def get_target_tags(base_dir, targets):
    """Collect the role tags used by the playbooks of the given test targets.

    Each target maps to the playbook 'test/integration/<target>.yml' under
    the given base directory.

    Args:
        base_dir: The root of the ansible source code.
        targets: The test targets to scan for tags.

    Returns: Set of role tags.
    """

    integration_dir = os.path.join(base_dir, 'test', 'integration')
    collected = set()

    for name in targets:
        collected.update(get_role_tags(os.path.join(integration_dir, name + '.yml')))

    return collected


def get_role_tags(playbook_path):
    """Collect the tags of every role in every play of a playbook.

    Args:
        playbook_path: Path to the playbook to get role tags from.

    Returns: Set of role tags.
    """

    variable_manager = VariableManager()
    loader = DataLoader()
    loaded = Playbook.load(playbook_path, variable_manager=variable_manager, loader=loader)

    found = set()

    for play in loaded.get_plays():
        for role in play.get_roles():
            found.update(role.tags)

    return found


def get_changed_paths(git_root, branch='origin/stable-2.2'):
    """Get file paths changed in current branch vs given branch.

    Runs 'git diff --name-only <branch>' with git_root as the working
    directory.

    Args:
        git_root: The directory inside the git clone to run git from.
        branch: The branch to compare against (default: origin/stable-2.2)

    Returns: List of file paths changed (empty if there are no changes).

    Raises:
        subprocess.CalledProcessError: If git exits with a non-zero status.
    """

    output = subprocess.check_output(['git', 'diff', '--name-only', branch], cwd=git_root)

    # check_output returns bytes on Python 3; decode so the result can be
    # split and compared as text. On Python 2 the output is already str.
    if not isinstance(output, str):
        output = output.decode('utf-8')

    # Skip blank lines so that no output gives [] instead of [''].
    return [line for line in output.splitlines() if line.strip()]


def get_modules(paths):
    """Get module names from file paths.

    A path is treated as a module when it has a module file extension
    ('.py' or '.ps1'), is inside a directory, is not under 'test/', and is
    not a package '__init__.py'. The module name is the file name without
    its extension and with leading/trailing underscores removed
    (e.g. '_old.py' becomes 'old').

    Args:
        paths: List of paths to extract module names from.

    Returns: List of module names.
    """

    module_extensions = ('.py', '.ps1')

    modules = []

    for changed in paths:
        file_name = path.basename(changed)

        if path.splitext(changed)[1] not in module_extensions:
            continue

        # Files at the top level of the repository are never modules.
        if '/' not in changed:
            continue

        if changed.startswith('test/'):
            continue

        # Compare the whole file name; package markers are not modules.
        if file_name == '__init__.py':
            continue

        modules.append(path.splitext(file_name)[0].strip('_'))

    return modules


def get_module_test_tags(modules):
    """Build the test tag for each module name by prefixing it with 'test_'.

    Args:
        modules: List of module names to get test tags for.

    Returns: Set of test tags.
    """

    tags = set()

    for name in modules:
        tags.add('test_' + name)

    return tags


# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
